/*********************************************************************
* Rice University Software Distribution License
*
* Copyright (c) 2010, Rice University
* All Rights Reserved.
*
* For a full description see the file named LICENSE.
*
*********************************************************************/

/* Author: Ioan Sucan */

#include <ompl/extensions/ode/ODESimpleSetup.h>
#include <ompl/base/GoalRegion.h>
#include <ompl/config.h>

// Short aliases for the OMPL namespaces.
namespace ob = ompl::base;
namespace og = ompl::geometric;
namespace oc = ompl::control;

/// @cond IGNORE

// Define the goal we want to reach: a target point (x, y) for body 0 of the
// ODE state, with an extra cost when the body's velocity does not point
// towards that target.
class CarGoal : public ob::GoalRegion
{
public:

    // Remembers the target coordinates and sets the goal threshold to 0.25.
    CarGoal(const ob::SpaceInformationPtr &si, double x, double y) : ob::GoalRegion(si), x_(x), y_(y)
    {
        threshold_ = 0.25;
    }

    // Returns the distance in the first two position coordinates from body 0
    // to the target, plus a velocity penalty. With d the offset from the body
    // to the target and v the body's linear velocity (first two components),
    // the penalty is 0 when d . v > 0 and sqrt(|d . v|) otherwise.
    virtual double distanceGoal(const ob::State *st) const
    {
        const oc::ODEStateSpace::StateType *odeState = st->as<oc::ODEStateSpace::StateType>();
        const double *position = odeState->getBodyPosition(0);
        const double *velocity = odeState->getBodyLinearVelocity(0);

        const double offsetX = x_ - position[0];
        const double offsetY = y_ - position[1];
        const double alignment = offsetX * velocity[0] + offsetY * velocity[1];

        const double penalty = (alignment > 0) ? 0.0 : sqrt(fabs(alignment));
        const double planar = sqrt(offsetX * offsetX + offsetY * offsetY);

        return planar + penalty;
    }

private:

    // Target coordinates the distance is measured to.
    double x_, y_;

};


// Define how we project a state: the projection is the first two position
// coordinates of body 0.
class CarStateProjectionEvaluator : public ob::ProjectionEvaluator
{
public:

    CarStateProjectionEvaluator(const ob::StateSpace *space) : ob::ProjectionEvaluator(space)
    {
    }

    // The projection has two components.
    virtual unsigned int getDimension(void) const
    {
        return 2;
    }

    // Uses a cell size of 1.0 along each of the two projected dimensions.
    virtual void defaultCellSizes(void)
    {
        const unsigned int dims = 2;
        cellSizes_.resize(dims);
        for (unsigned int i = 0; i < dims; ++i)
            cellSizes_[i] = 1.0;
    }

    // Copies the first two position coordinates of body 0 into the projection.
    virtual void project(const ob::State *state, ob::EuclideanProjection &projection) const
    {
        const oc::ODEStateSpace::StateType *odeState = state->as<oc::ODEStateSpace::StateType>();
        const double *bodyPos = odeState->getBodyPosition(0);
        for (unsigned int i = 0; i < 2; ++i)
            projection[i] = bodyPos[i];
    }

};

// Define our own space, to include a distance function we want and register a default projection
class CarStateSpace : public oc::ODEStateSpace
{
public:

    CarStateSpace(const oc::ODEEnvironmentPtr &env) : oc::ODEStateSpace(env)
    {
    }

    // Distance between two states, measured as the Euclidean distance between
    // the first two position coordinates of body 0 in each state. All other
    // state components are ignored.
    virtual double distance(const ob::State *s1, const ob::State *s2) const
    {
        const double *p1 = s1->as<oc::ODEStateSpace::StateType>()->getBodyPosition(0);
        // The second position must come from s2; reading it from s1 would
        // make every distance evaluate to zero.
        const double *p2 = s2->as<oc::ODEStateSpace::StateType>()->getBodyPosition(0);
        const double dx = p1[0] - p2[0];
        const double dy = p1[1] - p2[1];
        return sqrt(dx * dx + dy * dy);
    }

    // Registers CarStateProjectionEvaluator as the default projection for this space.
    virtual void registerProjections(void)
    {
        registerDefaultProjection(ob::ProjectionEvaluatorPtr(new CarStateProjectionEvaluator(this)));
    }

};

// Control sampler that produces the next control by nudging the previous one:
// each of the first two control components may be shifted by a fixed step in
// a random direction, and is pulled back inside the bounds if it leaves them.
class CarControlSampler : public oc::RealVectorControlUniformSampler
{
public:

    CarControlSampler(const oc::ControlSpace *cm) : oc::RealVectorControlUniformSampler(cm)
    {
    }

    // Copies `previous` into `control`, then, independently for components 0
    // and 1, perturbs the value by +/- 0.05 whenever a uniform01() draw
    // exceeds 0.3.
    virtual void sampleNext(oc::Control *control, const oc::Control *previous)
    {
        // Size of a single perturbation, shared by both components.
        static const double STEP = 0.05;

        space_->copyControl(control, previous);
        const ob::RealVectorBounds &bounds = space_->as<oc::ODEControlSpace>()->getBounds();
        oc::ODEControlSpace::ControlType *target = control->as<oc::ODEControlSpace::ControlType>();

        for (unsigned int i = 0; i < 2; ++i)
        {
            // Leave this component untouched unless the draw exceeds 0.3.
            if (!(rng_.uniform01() > 0.3))
                continue;

            double &value = target->values[i];
            const int direction = rng_.uniformBool() ? 1 : -1;
            value += direction * STEP;

            // An out-of-range value is placed one step inside the violated bound.
            if (value > bounds.high[i])
                value = bounds.high[i] - STEP;
            if (value < bounds.low[i])
                value = bounds.low[i] + STEP;
        }
    }

    // State-aware overload: the state is not used; sampling is delegated to
    // the two-argument version.
    virtual void sampleNext(oc::Control *control, const oc::Control *previous, const ob::State *state)
    {
        sampleNext(control, previous);
    }
};

// Control space for the car; identical to oc::ODEControlSpace except that it
// hands out CarControlSampler instances.
class CarControlSpace : public oc::ODEControlSpace
{
public:

    CarControlSpace(const ob::StateSpacePtr &m) : oc::ODEControlSpace(m)
    {
    }

    // Allocates a new CarControlSampler bound to this control space.
    virtual oc::ControlSamplerPtr allocControlSampler(void) const
    {
        oc::ControlSamplerPtr sampler(new CarControlSampler(this));
        return sampler;
    }
};

/// @endcond
